// Copyright Project Harbor Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package scan

import (
	"context"

	"github.com/goharbor/harbor/src/lib"
	"github.com/goharbor/harbor/src/lib/log"
	"github.com/goharbor/harbor/src/lib/orm"
	"github.com/goharbor/harbor/src/lib/q"
)

// init registers the vulnerability record model and the model that links a
// report (Report) to a vulnerability record (VulnRecordID) with the ORM layer
// when the package is loaded.
func init() {
	orm.RegisterModel(new(VulnerabilityRecord), new(ReportVulnerabilityRecord))
}

// VulnerabilityRecordDao is the data-access contract for vulnerability records
// and for the mappings that associate those records with scan reports.
type VulnerabilityRecordDao interface {
	// Create persists a vulnerability record and returns its ID.
	Create(ctx context.Context, vr *VulnerabilityRecord) (int64, error)
	// Update writes changes to a vulnerability record; cols optionally selects
	// which columns are written.
	Update(ctx context.Context, vr *VulnerabilityRecord, cols ...string) error
	// Delete removes a vulnerability record.
	Delete(ctx context.Context, vr *VulnerabilityRecord) error
	// List returns the vulnerability records matching query.
	List(ctx context.Context, query *q.Query) ([]*VulnerabilityRecord, error)

	// InsertForReport associates the given vulnerability record IDs with a report.
	InsertForReport(ctx context.Context, reportUUID string, vulnerabilityRecordIDs ...int64) error
	// GetForReport returns the vulnerability records associated with a report.
	GetForReport(ctx context.Context, reportUUID string) ([]*VulnerabilityRecord, error)
	// DeleteForReport removes a report's record associations (not the records
	// themselves) and returns how many were deleted.
	DeleteForReport(ctx context.Context, reportUUID string) (int64, error)
	// DeleteForDigests removes the record associations of every report for the
	// given digests and returns how many were deleted.
	DeleteForDigests(ctx context.Context, digests ...string) (int64, error)

	// GetForScanner returns the vulnerability records stored by a scanner registration.
	GetForScanner(ctx context.Context, registrationUUID string) ([]*VulnerabilityRecord, error)
	// GetRecordIDsForScanner returns the IDs of the vulnerability records stored
	// by a scanner registration.
	GetRecordIDsForScanner(ctx context.Context, registrationUUID string) ([]int, error)
	// DeleteForScanner removes the vulnerability records stored by a scanner
	// registration and returns how many were deleted.
	DeleteForScanner(ctx context.Context, registrationUUID string) (int64, error)
}

// NewVulnerabilityRecordDao builds the default, ORM-backed implementation of
// VulnerabilityRecordDao.
func NewVulnerabilityRecordDao() VulnerabilityRecordDao {
	return new(vulnerabilityRecordDao)
}

// vulnerabilityRecordDao holds no state of its own; its methods reach the
// database through the context they are handed.
type vulnerabilityRecordDao struct{}

// Create reads or creates vr via orm.ReadOrCreate, matching existing rows on
// (cve_id, registration_uuid, package, package_version), and returns the
// record ID reported by orm.ReadOrCreate.
func (v *vulnerabilityRecordDao) Create(ctx context.Context, vr *VulnerabilityRecord) (int64, error) {
	// These four columns are the natural key of a vulnerability record.
	_, id, err := orm.ReadOrCreate(ctx, vr,
		"cve_id",
		"registration_uuid",
		"package",
		"package_version",
	)
	return id, err
}

// Delete removes the vulnerability record(s) whose cve_id, registration_uuid,
// package and package_version match the corresponding values in vr.
//
// The where-columns are named in snake_case, exactly as Create names the same
// natural key, so both methods visibly agree on what identifies a record.
func (v *vulnerabilityRecordDao) Delete(ctx context.Context, vr *VulnerabilityRecord) error {
	o, err := orm.FromContext(ctx)
	if err != nil {
		return err
	}
	_, err = o.Delete(vr, "cve_id", "registration_uuid", "package", "package_version")

	return err
}

// Update persists changes made to vr. The optional cols are passed straight
// through to the ORM's Update to select which columns are written.
func (v *vulnerabilityRecordDao) Update(ctx context.Context, vr *VulnerabilityRecord, cols ...string) error {
	ormer, err := orm.FromContext(ctx)
	if err != nil {
		return err
	}
	if _, err := ormer.Update(vr, cols...); err != nil {
		return err
	}
	return nil
}

// List returns the vulnerability records matching query. Keywords in the
// query are matched exactly.
//
// When no registration UUID (i.e. scanner ID) is specified, a CVE may appear
// several times in the result: once for every registered scanner that stores
// data about it. De-duplicating those rows, or bucketing them per scanner, is
// the caller's responsibility.
func (v *vulnerabilityRecordDao) List(ctx context.Context, query *q.Query) ([]*VulnerabilityRecord, error) {
	qs, err := orm.QuerySetter(ctx, &VulnerabilityRecord{}, query)
	if err != nil {
		return nil, err
	}

	// Start from a non-nil empty slice rather than a nil one.
	records := []*VulnerabilityRecord{}
	if _, err := qs.All(&records); err != nil {
		return records, err
	}
	return records, nil
}

// InsertForReport links the given vulnerability record IDs to the scan report
// identified by reportUUID.
//
// Repeated IDs are collapsed so each record is linked at most once per call.
// All links are inserted inside one transaction. Passing no IDs is a no-op.
// On failure a warning is logged and the error is returned.
func (v *vulnerabilityRecordDao) InsertForReport(ctx context.Context, reportUUID string, vulnerabilityRecordIDs ...int64) error {
	if len(vulnerabilityRecordIDs) == 0 {
		return nil
	}

	seen := lib.Set{}
	links := make([]*ReportVulnerabilityRecord, 0, len(vulnerabilityRecordIDs))
	for _, id := range vulnerabilityRecordIDs {
		if !seen.Exists(id) {
			seen.Add(id)
			links = append(links, &ReportVulnerabilityRecord{
				Report:       reportUUID,
				VulnRecordID: id,
			})
		}
	}

	insert := func(c context.Context) error {
		o, err := orm.FromContext(c)
		if err != nil {
			return err
		}
		// Bulk size of 100 rows per INSERT statement.
		_, err = o.InsertMulti(100, links)
		return err
	}

	txCtx := orm.SetTransactionOpNameToContext(ctx, "tx-insert-for-report")
	err := orm.WithTransaction(insert)(txCtx)
	if err == nil {
		return nil
	}

	log.G(ctx).WithFields(log.Fields{
		"error":  err,
		"report": reportUUID,
	}).Warningf("Could not associate vulnerability record to the report")
	return err
}

// DeleteForReport removes every report/vulnerability mapping belonging to the
// report identified by reportUUID and returns the number of rows deleted. The
// vulnerability records themselves are left untouched.
func (v *vulnerabilityRecordDao) DeleteForReport(ctx context.Context, reportUUID string) (int64, error) {
	o, err := orm.FromContext(ctx)
	if err != nil {
		return 0, err
	}
	// Naming the column makes the delete filter on report_uuid.
	return o.Delete(&ReportVulnerabilityRecord{Report: reportUUID}, "report_uuid")
}

// GetForReport returns the vulnerability records linked (through
// report_vulnerability_record) to the report identified by reportUUID.
func (v *vulnerabilityRecordDao) GetForReport(ctx context.Context, reportUUID string) ([]*VulnerabilityRecord, error) {
	o, err := orm.FromContext(ctx)
	if err != nil {
		return nil, err
	}

	query := `select vulnerability_record.* from vulnerability_record
			  inner join report_vulnerability_record on
			  vulnerability_record.id = report_vulnerability_record.vuln_record_id and report_vulnerability_record.report_uuid=?`
	records := []*VulnerabilityRecord{}
	if _, err := o.Raw(query, reportUUID).QueryRows(&records); err != nil {
		return records, err
	}
	return records, nil
}

// GetForScanner returns every vulnerability record whose registration_uuid
// equals registrationUUID, i.e. the records known to that scanner.
func (v *vulnerabilityRecordDao) GetForScanner(ctx context.Context, registrationUUID string) ([]*VulnerabilityRecord, error) {
	o, err := orm.FromContext(ctx)
	if err != nil {
		return nil, err
	}

	var records []*VulnerabilityRecord
	_, err = o.QueryTable(new(VulnerabilityRecord)).
		Filter("registration_uuid", registrationUUID).
		All(&records)
	if err != nil {
		return nil, err
	}
	return records, nil
}

// DeleteForScanner removes every vulnerability record whose registration_uuid
// equals registrationUUID and returns the number of rows deleted.
func (v *vulnerabilityRecordDao) DeleteForScanner(ctx context.Context, registrationUUID string) (int64, error) {
	o, err := orm.FromContext(ctx)
	if err != nil {
		return 0, err
	}
	// Naming the column makes the delete filter on registration_uuid.
	return o.Delete(&VulnerabilityRecord{RegistrationUUID: registrationUUID}, "registration_uuid")
}

// DeleteForDigests removes the report/vulnerability mappings of every scan
// report whose digest is one of digests. It returns the total number of mapping
// rows deleted.
//
// Calling it with no digests is a no-op. The guard matters: presumably an empty
// OrList adds no digest filter to the report query (NOTE(review): confirm in
// lib/orm), in which case every report would match and all mappings would go.
//
// This method does not open its own transaction. If a per-report delete fails,
// earlier deletions are not undone here, and 0 is returned alongside the error.
func (v *vulnerabilityRecordDao) DeleteForDigests(ctx context.Context, digests ...string) (int64, error) {
	if len(digests) == 0 {
		return 0, nil
	}

	reportDao := New()

	ol := q.OrList{}
	for _, digest := range digests {
		ol.Values = append(ol.Values, digest)
	}
	reports, err := reportDao.List(ctx, &q.Query{Keywords: q.KeyWords{"digest": &ol}})
	if err != nil {
		return 0, err
	}

	numRowsDeleted := int64(0)
	for _, report := range reports {
		delCount, err := v.DeleteForReport(ctx, report.UUID)
		if err != nil {
			return 0, err
		}
		numRowsDeleted += delCount
	}
	return numRowsDeleted, nil
}

// GetRecordIDsForScanner returns the id column of every vulnerability record
// whose registration_uuid equals registrationUUID.
//
// If the ORM cannot be obtained from ctx, nil is returned with the error. If
// the query itself fails, the (initially empty) ID slice is returned with it.
func (v *vulnerabilityRecordDao) GetRecordIDsForScanner(ctx context.Context, registrationUUID string) ([]int, error) {
	o, err := orm.FromContext(ctx)
	if err != nil {
		return nil, err
	}

	vulnRecordIDs := make([]int, 0)
	_, err = o.Raw("select id from vulnerability_record where registration_uuid = ?", registrationUUID).QueryRows(&vulnRecordIDs)
	return vulnRecordIDs, err
}
